---
title: Optimizing matrix multiplication
description: Consider an algorithm for multiplying matrices using three nested loops. The complexity of such an algorithm by definition should be O(n³), but there are...
sections: [Permutations,Nested loops,Comparing algorithms]
tags: [java,arrays,multidimensional arrays,matrices,rows,columns,layers,loops]
canonical_url: /en/2021/12/10/optimizing-matrix-multiplication.html
url_translated: /ru/2021/12/09/optimizing-matrix-multiplication.html
title_translated: Оптимизация умножения матриц
date: 2021.12.10
lang: en
---

Consider an algorithm for multiplying matrices using three nested loops. The time complexity of such
an algorithm is `O(n³)`, but the actual running time also depends on the execution environment:
the order in which the loops are nested noticeably affects the speed of the algorithm.

Let's compare the execution time of the algorithms for different permutations of the nested loops.
Take two matrices: {`L×M`} and {`M×N`} &rarr; three loops &rarr; six permutations:
`LMN`, `LNM`, `MLN`, `MNL`, `NLM`, `NML`.

The fastest algorithms are those that write data to the resulting matrix *row-wise in layers*:
`LMN` and `MLN`. The difference from the other algorithms is substantial (roughly 3.5 to 7.6 times
in the sample output at the end of this article), and the exact ratio depends on the execution environment.
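
A likely explanation, stated here as an assumption rather than something measured in this article, is the memory layout: a Java `int[][]` is an array of separate row arrays, so an inner loop over `j` walks along the rows `c[i]` and `b[k]` sequentially, while an inner loop over `i` or `k` jumps to a different row array on every step. The sketch below isolates that effect by summing the same matrix in both orders; the class name, method names, and matrix size are hypothetical.

```java
public class TraversalOrderSketch {
    // sums the elements row by row: the inner loop stays inside one row array
    public static long sumRowWise(int[][] matrix) {
        long sum = 0;
        for (int i = 0; i < matrix.length; i++)
            for (int j = 0; j < matrix[i].length; j++)
                sum += matrix[i][j];
        return sum;
    }

    // sums the elements column by column: every step of the inner loop
    // moves to a different row array (assumes a rectangular matrix)
    public static long sumColumnWise(int[][] matrix) {
        long sum = 0;
        int columns = matrix.length == 0 ? 0 : matrix[0].length;
        for (int j = 0; j < columns; j++)
            for (int i = 0; i < matrix.length; i++)
                sum += matrix[i][j];
        return sum;
    }

    public static void main(String[] args) {
        // hypothetical size, chosen only to make the difference measurable
        int size = 2000;
        int[][] matrix = new int[size][size];
        for (int i = 0; i < size; i++)
            for (int j = 0; j < size; j++)
                matrix[i][j] = i + j;
        long start = System.nanoTime();
        long rowSum = sumRowWise(matrix);
        long rowTime = System.nanoTime() - start;
        start = System.nanoTime();
        long columnSum = sumColumnWise(matrix);
        long columnTime = System.nanoTime() - start;
        // both sums are equal, only the access order differs
        System.out.println("row-wise:    " + rowSum + " in " + rowTime / 1_000_000 + " ms");
        System.out.println("column-wise: " + columnSum + " in " + columnTime / 1_000_000 + " ms");
    }
}
```

The timings of a single run like this are only indicative and, as with the multiplication itself, depend on the execution environment.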

*Further optimization: [Matrix multiplication in parallel streams]({{ '/en/2022/02/09/matrix-multiplication-parallel-streams.html' | relative_url }}).*

{% include heading.html text="Row-wise algorithm" hash="row-wise-algorithm" %}

The outer loop iterates over the rows of the first matrix `L`, the middle loop over the *common side*
of the two matrices `M`, and the inner loop over the columns of the second matrix `N`.
The resulting matrix is written row by row, and each row is filled in layers.

```java
/**
 * @param l rows of matrix 'a'
 * @param m columns of matrix 'a'
 *          and rows of matrix 'b'
 * @param n columns of matrix 'b'
 * @param a first matrix 'l×m'
 * @param b second matrix 'm×n'
 * @return resulting matrix 'l×n'
 */
public static int[][] matrixMultiplicationLMN(int l, int m, int n, int[][] a, int[][] b) {
    // resulting matrix
    int[][] c = new int[l][n];
    // iterate over the row indexes of matrix 'a'
    for (int i = 0; i < l; i++)
        // iterate over the indexes of the common side of the two matrices:
        // the columns of matrix 'a' and the rows of matrix 'b'
        for (int k = 0; k < m; k++)
            // iterate over the column indexes of matrix 'b'
            for (int j = 0; j < n; j++)
                // add the product of the elements a[i][k] and b[k][j]
                // to the running sum stored in c[i][j]
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
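
To make "each row is filled in layers" concrete, here is a minimal sketch that records the state of one result row after each step of the `k` loop. The class `RowWiseLayersSketch` and the method `rowLayers` are hypothetical names introduced only for this illustration, and the matrices are tiny sample values.

```java
import java.util.Arrays;

public class RowWiseLayersSketch {
    // returns the state of one row of the result after each layer 'k' is added;
    // 'aRow' is one row of the first matrix, 'b' is the second matrix
    public static int[][] rowLayers(int[] aRow, int[][] b) {
        int m = b.length, n = b[0].length;
        int[][] layers = new int[m][];
        int[] row = new int[n];
        for (int k = 0; k < m; k++) {
            // one layer: the whole row receives the contribution of index 'k'
            for (int j = 0; j < n; j++)
                row[j] += aRow[k] * b[k][j];
            // remember a copy of the partially filled row
            layers[k] = row.clone();
        }
        return layers;
    }

    public static void main(String[] args) {
        int[][] a = { {1, 2}, {3, 4} }, b = { {5, 6}, {7, 8} };
        for (int i = 0; i < a.length; i++)
            System.out.println("row " + i + ": " + Arrays.deepToString(rowLayers(a[i], b)));
    }
}
```

For these sample matrices the program prints `row 0: [[5, 6], [19, 22]]` and `row 1: [[15, 18], [43, 50]]`: the last layer of each row is the finished row of the product.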

{% include heading.html text="Layer-wise algorithm" hash="layer-wise-algorithm" %}

The outer loop iterates over the *common side* of the two matrices `M`, the middle loop over the rows
of the first matrix `L`, and the inner loop over the columns of the second matrix `N`.
The resulting matrix is written layer by layer, and each layer is filled row by row.

```java
/**
 * @param l rows of matrix 'a'
 * @param m columns of matrix 'a'
 *          and rows of matrix 'b'
 * @param n columns of matrix 'b'
 * @param a first matrix 'l×m'
 * @param b second matrix 'm×n'
 * @return resulting matrix 'l×n'
 */
public static int[][] matrixMultiplicationMLN(int l, int m, int n, int[][] a, int[][] b) {
    // resulting matrix
    int[][] c = new int[l][n];
    // iterate over the indexes of the common side of the two matrices:
    // the columns of matrix 'a' and the rows of matrix 'b'
    for (int k = 0; k < m; k++)
        // iterate over the row indexes of matrix 'a'
        for (int i = 0; i < l; i++)
            // iterate over the column indexes of matrix 'b'
            for (int j = 0; j < n; j++)
                // add the product of the elements a[i][k] and b[k][j]
                // to the running sum stored in c[i][j]
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
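
In this order every step of the outer loop adds one complete layer to the whole matrix: the products of column `k` of the first matrix and row `k` of the second. The following sketch takes a snapshot of the result after each layer; `LayerWiseSketch` and `matrixLayers` are hypothetical names used only for illustration.

```java
import java.util.Arrays;

public class LayerWiseSketch {
    // returns snapshots of the whole result matrix after each layer 'k' is added
    public static int[][][] matrixLayers(int[][] a, int[][] b) {
        int l = a.length, m = b.length, n = b[0].length;
        int[][] c = new int[l][n];
        int[][][] snapshots = new int[m][][];
        for (int k = 0; k < m; k++) {
            // one layer: every element of 'c' receives the contribution of index 'k'
            for (int i = 0; i < l; i++)
                for (int j = 0; j < n; j++)
                    c[i][j] += a[i][k] * b[k][j];
            // copy the current state row by row
            int[][] copy = new int[l][];
            for (int i = 0; i < l; i++)
                copy[i] = c[i].clone();
            snapshots[k] = copy;
        }
        return snapshots;
    }

    public static void main(String[] args) {
        int[][] a = { {1, 2}, {3, 4} }, b = { {5, 6}, {7, 8} };
        int[][][] snapshots = matrixLayers(a, b);
        for (int k = 0; k < snapshots.length; k++)
            System.out.println("layer " + k + ": " + Arrays.deepToString(snapshots[k]));
    }
}
```

For these sample matrices the program prints `layer 0: [[5, 6], [15, 18]]` and `layer 1: [[19, 22], [43, 50]]`; the last snapshot is the finished product.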

{% include heading.html text="Other algorithms" hash="other-algorithms" type="3" %}

In the remaining algorithms, the loop over the columns of the second matrix `N` is nested outside the
loop over the *common side* of the two matrices `M` and/or outside the loop over the rows of the first
matrix `L`, so it is no longer the innermost loop.

{% capture collapsed_md %}
```java
public static int[][] matrixMultiplicationLNM(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int i = 0; i < l; i++)
        for (int j = 0; j < n; j++)
            for (int k = 0; k < m; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
```java
public static int[][] matrixMultiplicationNLM(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int j = 0; j < n; j++)
        for (int i = 0; i < l; i++)
            for (int k = 0; k < m; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
```java
public static int[][] matrixMultiplicationMNL(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int k = 0; k < m; k++)
        for (int j = 0; j < n; j++)
            for (int i = 0; i < l; i++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
```java
public static int[][] matrixMultiplicationNML(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int j = 0; j < n; j++)
        for (int k = 0; k < m; k++)
            for (int i = 0; i < l; i++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
{% endcapture %}
{%- include collapsed_block.html summary="Code without comments" content=collapsed_md -%}
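
All six variants differ only in the nesting order of the same three loops. As an illustration, the sketch below expresses every permutation with one method that takes the order as a string; `LoopOrderSketch` and `multiply` are hypothetical names, the `order` argument is assumed to be a valid permutation of the letters `L`, `M`, `N`, and the extra indirection makes this version unsuitable for timing.

```java
import java.util.Arrays;

public class LoopOrderSketch {
    // 'order' lists the loops from the outermost to the innermost,
    // for example "LMN"; it is assumed to be a valid permutation
    public static int[][] multiply(String order, int[][] a, int[][] b) {
        int l = a.length, m = b.length, n = b[0].length;
        int[][] c = new int[l][n];
        // loop bounds and current indexes: 0 = L (i), 1 = M (k), 2 = N (j)
        int[] size = {l, m, n}, idx = new int[3];
        int outer = "LMN".indexOf(order.charAt(0));
        int middle = "LMN".indexOf(order.charAt(1));
        int inner = "LMN".indexOf(order.charAt(2));
        for (idx[outer] = 0; idx[outer] < size[outer]; idx[outer]++)
            for (idx[middle] = 0; idx[middle] < size[middle]; idx[middle]++)
                for (idx[inner] = 0; idx[inner] < size[inner]; idx[inner]++)
                    // the same statement as in all six methods above
                    c[idx[0]][idx[2]] += a[idx[0]][idx[1]] * b[idx[1]][idx[2]];
        return c;
    }

    public static void main(String[] args) {
        int[][] a = { {1, 2, 3}, {4, 5, 6} }, b = { {7, 8}, {9, 10}, {11, 12} };
        for (String order : new String[]{"LMN", "LNM", "MLN", "MNL", "NLM", "NML"})
            System.out.println(order + ": " + Arrays.deepToString(multiply(order, a, b)));
    }
}
```

For these sample matrices every line of the output ends with `[[58, 64], [139, 154]]`, since the loop order changes only the sequence of memory accesses and not the result.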

{% include heading.html text="Comparing algorithms" hash="comparing-algorithms" %}

To check, we take two matrices `A=[500×700]` and `B=[700×450]`, filled with random numbers. First, we
verify that the algorithms are implemented correctly: all results obtained must match.
Then we execute each method 10 times and calculate the average execution time.

```java
// start the program and output the result
public static void main(String[] args) throws Exception {
    // input data
    int l = 500, m = 700, n = 450, steps = 10;
    int[][] a = randomMatrix(l, m), b = randomMatrix(m, n);
    // map of methods for comparison
    var methods = new TreeMap<String, Callable<int[][]>>(Map.of(
            "LMN", () -> matrixMultiplicationLMN(l, m, n, a, b),
            "LNM", () -> matrixMultiplicationLNM(l, m, n, a, b),
            "MLN", () -> matrixMultiplicationMLN(l, m, n, a, b),
            "MNL", () -> matrixMultiplicationMNL(l, m, n, a, b),
            "NLM", () -> matrixMultiplicationNLM(l, m, n, a, b),
            "NML", () -> matrixMultiplicationNML(l, m, n, a, b)));
    int[][] last = null;
    // iterate over the methods map, check the correctness of the returned
    // results, all results obtained must be equal to each other
    for (var method : methods.entrySet()) {
        // next method for comparison
        var next = methods.higherEntry(method.getKey());
        // if the current method is not the last — compare the results of two methods
        if (next != null) System.out.println(method.getKey() + "=" + next.getKey() + ": "
                // compare the result of executing the current method and the next one
                + Arrays.deepEquals(method.getValue().call(), next.getValue().call()));
            // the result of the last method
        else last = method.getValue().call();
    }
    int[][] test = last;
    // iterate over the methods map, measure the execution time of each method
    for (var method : methods.entrySet())
        // parameters: title, number of steps, runnable code
        benchmark(method.getKey(), steps, () -> {
            try { // execute the method, get the result
                int[][] result = method.getValue().call();
                // check the correctness of the results at each step
                if (!Arrays.deepEquals(result, test)) System.out.print("error");
            } catch (Exception e) {
                e.printStackTrace();
            }
        });
}
```

{% capture collapsed_md %}
```java
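// helper method, returns a matrix of the specified size filled with random numbers;
// note: products of such values can exceed the int range, but int arithmetic wraps
// around identically in every loop order, so the comparison of results is unaffected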
private static int[][] randomMatrix(int row, int col) {
    int[][] matrix = new int[row][col];
    for (int i = 0; i < row; i++)
        for (int j = 0; j < col; j++)
            matrix[i][j] = (int) (Math.random() * row * col);
    return matrix;
}
```
```java
// helper method for measuring the execution time of the passed code
private static void benchmark(String title, int steps, Runnable runnable) {
    long time, avg = 0;
    System.out.print(title);
    for (int i = 0; i < steps; i++) {
        time = System.currentTimeMillis();
        runnable.run();
        time = System.currentTimeMillis() - time;
        // execution time of one step
        System.out.print(" | " + time);
        avg += time;
    }
    // average execution time
    System.out.println(" || " + (avg / steps));
}
```
{% endcapture %}
{%- include collapsed_block.html summary="Helper methods" content=collapsed_md -%}

The output depends on the execution environment. Times are in milliseconds, and the last column is the average:

```
LMN=LNM: true
LNM=MLN: true
MLN=MNL: true
MNL=NLM: true
NLM=NML: true
LMN | 191 | 109 | 105 | 106 | 105 | 106 | 106 | 105 | 123 | 109 || 116
LNM | 417 | 418 | 419 | 416 | 416 | 417 | 418 | 417 | 416 | 417 || 417
MLN | 113 | 115 | 113 | 115 | 114 | 114 | 114 | 115 | 114 | 113 || 114
MNL | 857 | 864 | 857 | 859 | 860 | 863 | 862 | 860 | 858 | 860 || 860
NLM | 404 | 404 | 407 | 404 | 406 | 405 | 405 | 404 | 403 | 404 || 404
NML | 866 | 872 | 867 | 868 | 867 | 868 | 867 | 873 | 869 | 863 || 868
```
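
In the output above, the first `LMN` measurement (191 ms) is well above the following ones, which would be consistent with JIT warm-up during the first run; this is an assumption about that particular run. A minimal sketch of a measurement helper that performs a few unmeasured runs first and uses `System.nanoTime()` for elapsed time is shown below. The names `WarmupBenchmarkSketch` and `averageMillis` and the stand-in workload are hypothetical.

```java
public class WarmupBenchmarkSketch {
    // runs the code 'warmups' times without measuring, then 'steps' times
    // with measuring, and returns the average time of the measured runs
    // in milliseconds (assumes steps > 0)
    public static long averageMillis(int warmups, int steps, Runnable runnable) {
        for (int i = 0; i < warmups; i++)
            runnable.run();
        long total = 0;
        for (int i = 0; i < steps; i++) {
            long start = System.nanoTime();
            runnable.run();
            total += System.nanoTime() - start;
        }
        // nanoseconds to milliseconds
        return total / steps / 1_000_000;
    }

    public static void main(String[] args) {
        // stand-in workload; in the program above this would be a call
        // such as matrixMultiplicationLMN(l, m, n, a, b)
        long[] sink = new long[1];
        Runnable workload = () -> {
            for (int i = 0; i < 50_000_000; i++)
                sink[0] += i;
        };
        System.out.println("average: " + averageMillis(3, 10, workload) + " ms");
    }
}
```

This remains a rough measurement; the number of warm-up runs is an arbitrary choice, and the results still depend on the execution environment.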

All the methods described above, including those in the collapsed blocks, can be placed in one class.
The code uses `var` and `Map.of`, so it requires Java 10 or later.

{% capture collapsed_md %}
```java
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Callable;
```
{% endcapture %}
{%- include collapsed_block.html summary="Required imports" content=collapsed_md -%}
